import random
from collections import defaultdict

import numpy as np
import torch
from datasets import Dataset, load_dataset
from tqdm import tqdm

from .Base import BaseDataset, UnlearnDataset


class Pile(BaseDataset):
    """Sub-sampled view of the ``monology/pile-uncopyrighted`` corpus.

    After construction ``self.dataset`` is a mapping with two splits:

    * ``"train"`` -- a random subset of the upstream ``train`` split, whose
      size is controlled by ``ratio``.
    * ``"test"`` -- the upstream ``validation`` split, unmodified.

    Args:
        dataset_name: Identifier stored on the instance. It is not used to
            select what gets loaded; the Hugging Face repo is hard-coded.
        seed: Seed applied to ``torch``, ``numpy`` and ``random`` right
            before the subset is drawn.
        ratio: Forwarded as ``test_size`` to ``Dataset.train_test_split``;
            the resulting held-out part becomes the ``"train"`` split.
    """

    def __init__(self, dataset_name, seed=42, ratio=0.001):
        self.dataset_name = dataset_name
        # get_dataset() reads self.seed and self.ratio, so both must be
        # assigned before it is called.
        self.seed = seed
        self.ratio = ratio
        self.dataset = self.get_dataset()

    def get_dataset(self):
        """Load the corpus and build the ``train``/``test`` splits.

        Returns:
            The object returned by ``train_test_split`` with ``"train"``
            replaced by the sampled subset and ``"test"`` replaced by the
            upstream ``validation`` split.
        """
        # A single load gives access to every upstream split, so the
        # validation data does not need a second load_dataset() call.
        raw = load_dataset("monology/pile-uncopyrighted", cache_dir="./.cache")

        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        dataset = raw["train"].train_test_split(test_size=self.ratio)

        # Keep only the small held-out fraction as training data; the large
        # remainder produced by the split is discarded.
        dataset["train"] = dataset["test"]
        dataset["test"] = raw["validation"]

        return dataset

    def __preprocess__(self, tokenizer):
        """Tokenize both splits in place and expose them as torch tensors.

        Args:
            tokenizer: Callable applied to a batch of texts; its result must
                provide ``input_ids`` and ``attention_mask`` attributes.

        Returns:
            ``self.dataset`` with the ``"text"`` column replaced by
            ``input_ids``, ``attention_mask`` and ``label`` columns.
        """
        columns = ["input_ids", "attention_mask", "label"]

        def preprocess(examples):
            tokenized = tokenizer(
                examples["text"],
                truncation=True,
                padding=True,
                add_special_tokens=True,
                max_length=1024,
            )
            return {
                "input_ids": tokenized.input_ids,
                "attention_mask": tokenized.attention_mask,
                # The label column is a copy of the input ids.
                "label": tokenized.input_ids,
            }

        for split in ("train", "test"):
            tokenized_split = self.dataset[split].map(
                preprocess, batched=True, remove_columns=["text"]
            )
            tokenized_split.set_format(type="torch", columns=columns)
            self.dataset[split] = tokenized_split

        return self.dataset

    def build_dataset(self, tokenizer):
        """Tokenize the splits with ``tokenizer`` and return ``self.dataset``."""
        self.__preprocess__(tokenizer)
        return self.dataset
